import json
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import patch

import pytest
from freezegun import freeze_time

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.bigquery_v2.bigquery_queries import (
    BigQueryQueriesSource,
    BigQueryQueriesSourceReport,
)
from datahub.metadata.urns import CorpUserUrn
from datahub.sql_parsing.sql_parsing_aggregator import ObservedQuery
from datahub.testing import mce_helpers
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
from tests.test_helpers.state_helpers import run_and_get_pipeline

# Wall-clock time that freezegun pins for the duration of test_queries_ingestion.
FROZEN_TIME = "2024-08-19 07:00:00"
# Explicit `end_time` of the extraction window configured in test_queries_ingestion.
WINDOW_END_TIME = "2024-09-01 00:00:00Z"


def _generate_queries_cached_file(tmp_path: Path, queries_json_path: Path) -> None:
    """Write the queries from a JSON dump into a sqlite-backed ``FileBackedList``.

    The cache is created at ``tmp_path / "audit_log.sqlite"``. Presumably the
    source under test finds it there because its ``local_temp_path`` is set to
    the same ``tmp_path``.

    Args:
        tmp_path: Directory in which ``audit_log.sqlite`` is created.
        queries_json_path: JSON file containing a list of dicts. Each dict holds
            the keyword arguments for one ``ObservedQuery``. ``timestamp`` must
            be an ISO-8601 string. ``user`` must be a CorpUser URN string or a
            falsy value (which maps to ``None``).
    """
    # We choose to generate Cached audit log (FileBackedList backed by sqlite) at runtime
    # instead of using pre-existing sqlite file here as default serializer for FileBackedList
    # uses pickle which may not work well across python versions.

    # Read and validate the input before touching sqlite, so a bad or missing
    # JSON file does not leave an open connection behind.
    with open(queries_json_path, "r", encoding="utf-8") as f:
        queries = json.load(f)
    assert isinstance(queries, list)

    shared_connection = ConnectionWrapper(tmp_path / "audit_log.sqlite")
    try:
        query_cache: FileBackedList[ObservedQuery] = FileBackedList(shared_connection)
        try:
            for query in queries:
                # Convert the JSON-serialized fields back into the types
                # ObservedQuery expects.
                query["timestamp"] = datetime.fromisoformat(query["timestamp"])
                query["user"] = (
                    CorpUserUrn.from_string(query["user"]) if query["user"] else None
                )
                query_cache.append(ObservedQuery(**query))
        finally:
            # Close the list before the connection it shares, even on failure.
            query_cache.close()
    finally:
        shared_connection.close()


@freeze_time(FROZEN_TIME)
@patch("google.cloud.bigquery.Client")
@patch("google.cloud.resourcemanager_v3.ProjectsClient")
def test_queries_ingestion(project_client, client, pytestconfig, monkeypatch, tmp_path):
    """Run the ``bigquery-queries`` source on a canned query log and compare
    the emitted MCPs with the checked-in golden file.

    The BigQuery and Resource Manager clients are patched out, and time is
    frozen at ``FROZEN_TIME``. The queries come from a sqlite cache that is
    built from ``query_log.json`` inside ``tmp_path``.
    """
    resources_dir = pytestconfig.rootpath / "tests/integration/bigquery_v2"
    golden_file = f"{resources_dir}/bigquery_queries_mcps_golden.json"
    output_file = tmp_path / "bigquery_queries_mcps.json"

    # query_log.json was originally created from the queries dump generated by the
    # acryl bigquery connector smoke test. The `datahub check extract-sql-agg-log`
    # command (with tablename="data") converted the cached audit log to queries
    # json, followed by a simple `acryl-staging`->`gcp-staging` replacement.
    try:
        _generate_queries_cached_file(tmp_path, resources_dir / "query_log.json")
    except Exception as e:
        pytest.fail(f"Failed to generate queries sqlite cache: {e}")

    source_config: dict = {
        "project_ids": ["gcp-staging", "gcp-staging-2"],
        "local_temp_path": tmp_path,
        "top_n_queries": 20,
        "window": {"start_time": "-30d", "end_time": WINDOW_END_TIME},
    }
    pipeline = run_and_get_pipeline(
        {
            "source": {"type": "bigquery-queries", "config": source_config},
            "sink": {"type": "file", "config": {"filename": str(output_file)}},
        }
    )
    pipeline.pretty_print_summary()

    # Walk down the report tree, checking each level is populated before use.
    report = pipeline.source.get_report()
    assert isinstance(report, BigQueryQueriesSourceReport)
    extractor_report = report.queries_extractor
    assert extractor_report is not None
    aggregator_report = extractor_report.sql_aggregator
    assert aggregator_report is not None
    assert aggregator_report.num_query_usage_stats_generated > 0
    assert aggregator_report.num_query_usage_stats_outside_window == 0

    mce_helpers.check_golden_file(
        pytestconfig,
        output_path=output_file,
        golden_path=golden_file,
    )


@patch("google.cloud.bigquery.Client")
@patch("google.cloud.resourcemanager_v3.ProjectsClient")
def test_source_close_cleans_tmp(projects_client, client, tmp_path):
    """Check that closing the source removes every temp file it created.

    ``tempfile.tempdir`` is redirected to ``tmp_path``. Temp files created
    through the ``tempfile`` module therefore land there, where they can be
    listed before and after ``close()``.
    """
    with patch("tempfile.tempdir", str(tmp_path)):
        source = BigQueryQueriesSource.create(
            {"project_ids": ["project1"]}, PipelineContext("run-id")
        )
        # Creating the source must have produced something in the temp dir;
        # otherwise the cleanup check below would be vacuous.
        assert os.listdir(tmp_path)

        # This closes QueriesExtractor which in turn closes SqlParsingAggregator
        source.close()

        leftovers = os.listdir(tmp_path)
        assert not leftovers, f"Files left in {tmp_path}: {leftovers}"
